/**
 * @author LKQ
 * @date 2022/2/24 15:32
 * @description
 */
public class Solution {
    public static void main(String[] args) {

    }
    public int projectionArea(int[][] grid) {
        int n = grid.length;
        // cnt1记录xy上投影面积，cnt2记录xz上投影，cnt3记录yz上投影面积
        int cnt1 = 0, cnt2 = 0, cnt3 = 0;
        for(int i = 0; i < n; i++) {
            int max2 = 0, max3 = 0;
            for(int j = 0; j < n; j++) {
                max2 = Math.max(max2, grid[i][j]);
                max3 = Math.max(max3, grid[j][i]);
                cnt1 += Math.min(1, grid[i][j]);
            }
            cnt2 += max2;
            cnt3 += max3;
        }
        return cnt1 + cnt2 + cnt3;
    }
}
